{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d2abd10e-e63e-4904-badf-5a16409503b1",
   "metadata": {},
   "source": [
    "# LoRA -- A Multilayer Perceptron Example"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "263e27da-47c7-4030-83c6-bf5f7e8bef74",
   "metadata": {},
   "source": [
    "This code notebook illustrates how LoRA ([https://arxiv.org/abs/2106.09685](https://arxiv.org/abs/2106.09685)) works by implementing it from scratch in the context of a multilayer perceptron (not LLM) to illustrate it with a simple example.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c1c52f02-94fb-4f45-902e-79126e27347d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.nn.functional as F\n",
    "import torch.nn as nn\n",
    "import torch\n",
    "\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "629ec66a-eb81-40a5-ae3d-d5c1d2a7e390",
   "metadata": {},
   "source": [
    "## Settings and Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4ade5e86-8bd8-4a35-8db1-44451601b292",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████| 9912422/9912422 [00:01<00:00, 8702295.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████| 28881/28881 [00:00<00:00, 31844293.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████| 1648877/1648877 [00:00<00:00, 4452323.08it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████| 4542/4542 [00:00<00:00, 9097673.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw\n",
      "\n",
      "Image batch dimensions: torch.Size([64, 1, 28, 28])\n",
      "Image label dimensions: torch.Size([64])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Device\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "BATCH_SIZE = 64\n",
    "\n",
    "##########################\n",
    "### MNIST DATASET\n",
    "##########################\n",
    "\n",
    "# Note transforms.ToTensor() scales input images\n",
    "# to 0-1 range\n",
    "train_dataset = datasets.MNIST(root='data', \n",
    "                               train=True, \n",
    "                               transform=transforms.ToTensor(),\n",
    "                               download=True)\n",
    "\n",
    "test_dataset = datasets.MNIST(root='data', \n",
    "                              train=False, \n",
    "                              transform=transforms.ToTensor())\n",
    "\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset, \n",
    "                          batch_size=BATCH_SIZE, \n",
    "                          shuffle=True)\n",
    "\n",
    "test_loader = DataLoader(dataset=test_dataset, \n",
    "                         batch_size=BATCH_SIZE, \n",
    "                         shuffle=False)\n",
    "\n",
    "# Checking the dataset\n",
    "for images, labels in train_loader:  \n",
    "    print('Image batch dimensions:', images.shape)\n",
    "    print('Image label dimensions:', labels.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "394e5da8-2978-40f0-bca7-b79e8e35734f",
   "metadata": {},
   "source": [
    "# Multilayer Perceptron Model (Without LoRA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7e905c42-7f59-4a08-b6c5-a10f99f33e9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### MODEL\n",
    "##########################\n",
    "\n",
    "# Hyperparameters\n",
    "random_seed = 123\n",
    "learning_rate = 0.005\n",
    "num_epochs = 2\n",
    "\n",
    "# Architecture\n",
    "num_features = 784\n",
    "num_hidden_1 = 128\n",
    "num_hidden_2 = 256\n",
    "num_classes = 10\n",
    "\n",
    "\n",
    "class MultilayerPerceptron(nn.Module):\n",
    "\n",
    "    def __init__(self, num_features, num_hidden_1, num_hidden_2, num_classes):\n",
    "        super().__init__()\n",
    "\n",
    "        self.layers = nn.Sequential(\n",
    "            nn.Linear(num_features, num_hidden_1),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(num_hidden_1, num_hidden_2),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(num_hidden_2, num_classes)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.layers(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "torch.manual_seed(random_seed)\n",
    "model_pretrained = MultilayerPerceptron(\n",
    "    num_features=num_features,\n",
    "    num_hidden_1=num_hidden_1,\n",
    "    num_hidden_2=num_hidden_2, \n",
    "    num_classes=num_classes\n",
    ")\n",
    "\n",
    "model_pretrained.to(DEVICE)\n",
    "optimizer_pretrained = torch.optim.Adam(model_pretrained.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "cf31624a-d950-402f-a564-2e7fb63db8a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_accuracy(model, data_loader, device):\n",
    "    model.eval()\n",
    "    correct_pred, num_examples = 0, 0\n",
    "    with torch.no_grad():\n",
    "        for features, targets in data_loader:\n",
    "            features = features.view(-1, 28*28).to(device)\n",
    "            targets = targets.to(device)\n",
    "            logits = model(features)\n",
    "            _, predicted_labels = torch.max(logits, 1)\n",
    "            num_examples += targets.size(0)\n",
    "            correct_pred += (predicted_labels == targets).sum()\n",
    "        return correct_pred.float()/num_examples * 100\n",
    "\n",
    "\n",
    "def train(num_epochs, model, optimizer, train_loader, device):\n",
    "\n",
    "    start_time = time.time()\n",
    "    for epoch in range(num_epochs):\n",
    "        model.train()\n",
    "        for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "\n",
    "            features = features.view(-1, 28*28).to(device)\n",
    "            targets = targets.to(device)\n",
    "\n",
    "            # FORWARD AND BACK PROP\n",
    "            logits = model(features)\n",
    "            loss = F.cross_entropy(logits, targets)\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            loss.backward()\n",
    "\n",
    "            # UPDATE MODEL PARAMETERS\n",
    "            optimizer.step()\n",
    "\n",
    "            # LOGGING\n",
    "            if not batch_idx % 400:\n",
    "                print('Epoch: %03d/%03d | Batch %03d/%03d | Loss: %.4f'\n",
    "                      % (epoch+1, num_epochs, batch_idx,\n",
    "                          len(train_loader), loss))\n",
    "\n",
    "        with torch.set_grad_enabled(False):\n",
    "            print('Epoch: %03d/%03d training accuracy: %.2f%%' % (\n",
    "                  epoch+1, num_epochs,\n",
    "                  compute_accuracy(model, train_loader, device)))\n",
    "\n",
    "        print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
    "\n",
    "    print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f47cfe4e-65eb-440e-b922-17c6dee7d7e2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/002 | Batch 000/938 | Loss: 2.2971\n",
      "Epoch: 001/002 | Batch 400/938 | Loss: 0.2258\n",
      "Epoch: 001/002 | Batch 800/938 | Loss: 0.1612\n",
      "Epoch: 001/002 training accuracy: 95.71%\n",
      "Time elapsed: 0.04 min\n",
      "Epoch: 002/002 | Batch 000/938 | Loss: 0.0593\n",
      "Epoch: 002/002 | Batch 400/938 | Loss: 0.0588\n",
      "Epoch: 002/002 | Batch 800/938 | Loss: 0.0556\n",
      "Epoch: 002/002 training accuracy: 97.40%\n",
      "Time elapsed: 0.08 min\n",
      "Total Training Time: 0.08 min\n",
      "Test accuracy: 96.55%\n"
     ]
    }
   ],
   "source": [
    "train(num_epochs, model_pretrained, optimizer_pretrained, train_loader, DEVICE)\n",
    "print(f'Test accuracy: {compute_accuracy(model_pretrained, test_loader, DEVICE):.2f}%')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb3480b9-aea5-411e-b252-d7fc8a5dd21d",
   "metadata": {},
   "source": [
    "# Multilayer Perceptron with LoRA"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36b9d281-22ba-4120-af95-f6b95adcaa03",
   "metadata": {},
   "source": [
    "## Modify model by injecting LoRA Layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "215795c5-c0d4-4886-b4d6-a5a0e7cc8c7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LoRALayer(nn.Module):\n",
    "    def __init__(self, in_dim, out_dim, rank, alpha):\n",
    "        super().__init__()\n",
    "        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())\n",
    "        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)\n",
    "        self.B = nn.Parameter(torch.zeros(rank, out_dim))\n",
    "        self.alpha = alpha\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.alpha * (x @ self.A @ self.B)\n",
    "        return x\n",
    "\n",
    "\n",
    "class LinearWithLoRA(nn.Module):\n",
    "    def __init__(self, linear, rank, alpha):\n",
    "        super().__init__()\n",
    "        self.linear = linear\n",
    "        self.lora = LoRALayer(\n",
    "            linear.in_features, linear.out_features, rank, alpha\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.linear(x) + self.lora(x)\n",
    "\n",
    "    \n",
    "# This LoRA code is equivalent to LinearWithLoRA\n",
    "class LinearWithLoRAMerged(nn.Module):\n",
    "    def __init__(self, linear, rank, alpha):\n",
    "        super().__init__()\n",
    "        self.linear = linear\n",
    "        self.lora = LoRALayer(\n",
    "            linear.in_features, linear.out_features, rank, alpha\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        lora = self.lora.A @ self.lora.B\n",
    "        combined_weight = self.linear.weight + lora.T\n",
    "        return F.linear(x, combined_weight, self.linear.bias)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0441c93f-0ee5-4003-acc3-f24541f06c66",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original output: tensor([[0.6639, 0.4487]], grad_fn=<AddmmBackward0>)\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(123)\n",
    "\n",
    "layer = nn.Linear(10, 2)\n",
    "x = torch.randn((1, 10))\n",
    "\n",
    "print(\"Original output:\", layer(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b132184a-f87e-423f-850a-2dc44fe76770",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LoRA output: tensor([[0.6639, 0.4487]], grad_fn=<AddBackward0>)\n"
     ]
    }
   ],
   "source": [
    "layer_lora_1 = LinearWithLoRA(layer, rank=2, alpha=4)\n",
    "\n",
    "print(\"LoRA output:\", layer_lora_1(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f555c364-9c5f-4a20-8c8c-feb830131555",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LoRA output: tensor([[0.6639, 0.4487]], grad_fn=<AddmmBackward0>)\n"
     ]
    }
   ],
   "source": [
    "layer_lora_2 = LinearWithLoRAMerged(layer, rank=2, alpha=4)\n",
    "print(\"LoRA output:\", layer_lora_2(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "dc66ffa1-5822-4833-b636-d3a8170e84a2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MultilayerPerceptron(\n",
       "  (layers): Sequential(\n",
       "    (0): Linear(in_features=784, out_features=128, bias=True)\n",
       "    (1): ReLU()\n",
       "    (2): Linear(in_features=128, out_features=256, bias=True)\n",
       "    (3): ReLU()\n",
       "    (4): Linear(in_features=256, out_features=10, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_pretrained"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b00a7e8f-09ff-499e-b593-3f6dac87f1bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "model_lora = copy.deepcopy(model_pretrained)\n",
    "model_dora = copy.deepcopy(model_pretrained)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b1e3ef6f-255c-4c71-9da5-7d06d8c439a3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MultilayerPerceptron(\n",
       "  (layers): Sequential(\n",
       "    (0): LinearWithLoRA(\n",
       "      (linear): Linear(in_features=784, out_features=128, bias=True)\n",
       "      (lora): LoRALayer()\n",
       "    )\n",
       "    (1): ReLU()\n",
       "    (2): LinearWithLoRA(\n",
       "      (linear): Linear(in_features=128, out_features=256, bias=True)\n",
       "      (lora): LoRALayer()\n",
       "    )\n",
       "    (3): ReLU()\n",
       "    (4): LinearWithLoRA(\n",
       "      (linear): Linear(in_features=256, out_features=10, bias=True)\n",
       "      (lora): LoRALayer()\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_lora.layers[0] = LinearWithLoRA(model_lora.layers[0], rank=4, alpha=8)\n",
    "model_lora.layers[2] = LinearWithLoRA(model_lora.layers[2], rank=4, alpha=8)\n",
    "model_lora.layers[4] = LinearWithLoRA(model_lora.layers[4], rank=4, alpha=8)\n",
    "\n",
    "model_lora.to(DEVICE)\n",
    "optimizer_lora = torch.optim.Adam(model_lora.parameters(), lr=learning_rate)\n",
    "model_lora"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9756742d-f574-400a-8d4e-cc55233df83c",
   "metadata": {},
   "source": [
    "We just initialized the LoRA layers but haven't trained the LoRA layers yet, so a model with and without initial LoRA weights should have the same predictive performance:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "d2ac620b-2fdf-4b94-92bb-f8d00e640306",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy orig model: 96.55%\n",
      "Test accuracy LoRA model: 96.55%\n"
     ]
    }
   ],
   "source": [
    "print(f'Test accuracy orig model: {compute_accuracy(model_pretrained, test_loader, DEVICE):.2f}%')\n",
    "print(f'Test accuracy LoRA model: {compute_accuracy(model_lora, test_loader, DEVICE):.2f}%')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ceed732-7989-4f01-a5a1-1036eb41512d",
   "metadata": {},
   "source": [
    "## Train model with LoRA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "a35d4c20-f754-4e82-85a1-ad19a30b3dfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def freeze_linear_layers(model):\n",
    "    for child in model.children():\n",
    "        if isinstance(child, nn.Linear):\n",
    "            for param in child.parameters():\n",
    "                param.requires_grad = False\n",
    "        else:\n",
    "            # Recursively freeze linear layers in children modules\n",
    "            freeze_linear_layers(child)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "88454690-abe6-49de-986e-9a6fe7883000",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "layers.0.linear.weight: False\n",
      "layers.0.linear.bias: False\n",
      "layers.0.lora.A: True\n",
      "layers.0.lora.B: True\n",
      "layers.2.linear.weight: False\n",
      "layers.2.linear.bias: False\n",
      "layers.2.lora.A: True\n",
      "layers.2.lora.B: True\n",
      "layers.4.linear.weight: False\n",
      "layers.4.linear.bias: False\n",
      "layers.4.lora.A: True\n",
      "layers.4.lora.B: True\n"
     ]
    }
   ],
   "source": [
    "freeze_linear_layers(model_lora)\n",
    "\n",
    "# Check if linear layers are frozen\n",
    "for name, param in model_lora.named_parameters():\n",
    "    print(f\"{name}: {param.requires_grad}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "1b807c7b-8d4a-4a1e-8a56-42bbdbc82fed",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/002 | Batch 000/938 | Loss: 0.0843\n",
      "Epoch: 001/002 | Batch 400/938 | Loss: 0.2096\n",
      "Epoch: 001/002 | Batch 800/938 | Loss: 0.0998\n",
      "Epoch: 001/002 training accuracy: 97.46%\n",
      "Time elapsed: 0.04 min\n",
      "Epoch: 002/002 | Batch 000/938 | Loss: 0.0840\n",
      "Epoch: 002/002 | Batch 400/938 | Loss: 0.1385\n",
      "Epoch: 002/002 | Batch 800/938 | Loss: 0.0056\n",
      "Epoch: 002/002 training accuracy: 97.79%\n",
      "Time elapsed: 0.07 min\n",
      "Total Training Time: 0.07 min\n",
      "Test accuracy LoRA finetune: 96.88%\n"
     ]
    }
   ],
   "source": [
    "optimizer_lora = torch.optim.Adam(model_lora.parameters(), lr=learning_rate)\n",
    "train(num_epochs, model_lora, optimizer_lora, train_loader, DEVICE)\n",
    "print(f'Test accuracy LoRA finetune: {compute_accuracy(model_lora, test_loader, DEVICE):.2f}%')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
